import numpy as np

from toolkit.utils import agrid, markov_rouwenhorst, broyden_solver
from toolkit.simple_block import simple
from toolkit.jacobian import get_G
from toolkit.nonlinear import td_solve

import separable as sep
import sticky_price as sp


'''Part 1: Embed the heterogeneous-agent (HA) household block in the macro model'''


def taxes(T_rule, pi_e, T0, T1, tax0, tax1, tax_rule, w):
    """Household-level transfers, tax rates and after-tax wages.

    Transfers: T0 plus T1 spread across states in proportion to T_rule, where
    T_rule is rescaled so that its pi_e-weighted sum equals one.
    Tax rates: tax0 plus tax1 scaled by the incidence pattern tax_rule.
    After-tax wage: the net-of-tax share (1 - tax) times the wage w.
    """
    # NOTE(review): this function is attached as a hetinput; the toolkit may
    # infer output names from the final `return` line, so keep it verbatim.
    rule_scale = np.sum(pi_e * T_rule)
    T = T0 + T1 / rule_scale * T_rule
    tax = tax0 + tax1 * tax_rule
    net_rate = 1 - tax
    atw = net_rate * w
    return T, tax, atw


# Household block from the `separable` module with `taxes` attached as a
# heterogeneous input, so its outputs (T, tax, atw) are presumably computed
# before the household problem is solved.
# NOTE(review): exact semantics are defined by toolkit's attach_hetinput — confirm there.
household = sep.household.attach_hetinput(taxes)


'''Part 2: Simple blocks not already defined in the sticky-price model'''


@simple
def wage_setting(UCE, L, pi, muw, kappaw, nu, vphi, beta):
    """Wage Phillips curve residual, used as the 'wnkpc' GE target.

    Sums three terms:
    - kappaw * L * (vphi * L**nu - UCE / muw);
    - the discounted next-period term beta * log(1 + pi(+1));
    - minus the current term log(1 + pi).
    """
    labor_wedge = vphi * L**nu - UCE / muw
    forward_term = beta * np.log(1 + pi(+1))
    wnkpc = kappaw * L * labor_wedge + forward_term - np.log(1 + pi)
    return wnkpc


@simple
def dividend(Y, w, L):
    """Dividends: output Y net of the wage bill w * L."""
    wage_bill = w * L
    div = Y - wage_bill
    return div


@simple
def mkt_clearing(p, A, B, Y, C, G):
    """Residuals of the asset and goods markets (both zero in equilibrium).

    asset_mkt: equity value p plus bonds B, less household assets A.
    goods_mkt: consumption C plus government spending G, less output Y.
    """
    asset_supply = p + B
    spending = C + G
    goods_mkt = spending - Y
    asset_mkt = asset_supply - A
    return asset_mkt, goods_mkt


'''Part 3: model variants'''


def sw_benchmark(ss, T, linear=True, dG=None):
    """Solve the benchmark sticky-wage model (DAG uses the `sp.constant_r` block).

    Parameters
    ----------
    ss : dict
        Steady state, e.g. as produced by `calibrate`.
    T : int
        Truncation horizon passed to `get_G`.
    linear : bool, default True
        If True, compute the GE Jacobians with `get_G`. Otherwise compute the
        nonlinear transition with `td_solve`, where the path of G is
        ss['G'] + dG.
    dG : array-like, optional
        Government-spending shock; required when `linear` is False.

    Returns
    -------
    The result of `get_G` (linear) or of `td_solve` (nonlinear, with
    `returnindividual=True`).

    Raises
    ------
    ValueError
        If `linear` is False and `dG` is not given.
    """
    # Fail fast: without dG the nonlinear path ss['G'] + dG is undefined.
    if not linear and dG is None:
        raise ValueError('If you want a nonlinear solution, specify the shock dG.')

    # set up DAG
    block_list = [household, mkt_clearing, wage_setting, dividend,
                  sp.production, sp.debt_policy, sp.gov_budget1, sp.arbitrage, sp.interest_rates, sp.real_bonds,
                  sp.constant_r, sp.monetary, sp.gov_budget2]
    exogenous = ['G']
    unknowns = ['pi', 'L', 'tax1']
    targets = ['asset_mkt', 'wnkpc', 'gov_budget_res']

    if linear:
        return get_G(block_list, exogenous, unknowns, targets, T=T, ss=ss)
    return td_solve(ss, block_list, unknowns, targets, returnindividual=True, G=ss['G'] + dG)


def sw_taylor(ss, T):
    """GE Jacobians of the sticky-wage model when monetary policy follows a Taylor rule.

    The DAG matches the benchmark model except that `sp.taylor_rule` takes
    the place of `sp.constant_r`. The exogenous shock is G; the unknowns
    (pi, L, tax1) are solved so that the asset market, the wage Phillips curve
    and the government budget residual all clear.
    """
    # household side and the simple blocks defined in this module
    blocks = [household, mkt_clearing, wage_setting, dividend]
    # blocks shared with the sticky-price model, Taylor-rule closure
    blocks += [sp.production, sp.debt_policy, sp.gov_budget1, sp.arbitrage, sp.interest_rates, sp.real_bonds,
               sp.monetary, sp.taylor_rule, sp.gov_budget2]
    return get_G(blocks, ['G'], ['pi', 'L', 'tax1'], ['asset_mkt', 'wnkpc', 'gov_budget_res'], T=T, ss=ss)


def sw_peg(ss, T, k=12):
    """Compute GE Jacobian of sticky-wage model with a nominal peg for k periods followed by a Taylor rule.

    The `sp.taylor_rule` block is replaced by its Jacobian. In that Jacobian,
    the rows of d(i)/d(pi) for periods 0, ..., k-1 are zeroed, so the nominal
    rate does not respond to inflation during those periods. Only the response
    to `pi` is removed; any other inputs of the rule keep their effect.

    Parameters
    ----------
    ss : dict
        Steady state, e.g. from `calibrate`.
    T : int
        Truncation horizon.
    k : int, default 12
        Length of the peg in model periods (rows of the Jacobian), not years.
        The default calibration (r=0.005, B=0.55*4) suggests a period is a
        quarter — confirm. Values k >= T peg the rate over the whole horizon.

    Raises
    ------
    ValueError
        If k is negative. A negative slice bound would otherwise silently zero
        all but the last |k| rows.
    """
    if k < 0:
        raise ValueError(f'Peg length k must be non-negative, got {k}.')

    # compute Jacobian of Taylor rule (presumably a T x T matrix per input/output pair)
    J_taylor = sp.taylor_rule.jac(ss, T)

    # no response to inflation in first k periods
    J_taylor['i']['pi'][:k, :] = 0

    # set up DAG
    block_list = [household, mkt_clearing, wage_setting, dividend,
                  sp.production, sp.debt_policy, sp.gov_budget1, sp.arbitrage, sp.interest_rates, sp.real_bonds,
                  sp.monetary, J_taylor, sp.gov_budget2]
    exogenous = ['G']
    unknowns = ['pi', 'L', 'tax1']
    targets = ['asset_mkt', 'wnkpc', 'gov_budget_res']

    return get_G(block_list, exogenous, unknowns, targets, T=T, ss=ss)


'''Part 4: Calibration'''


def calibrate(betamax_guess=0.98, vphi_guess=0.88, sigma=2, nu=2, r=0.005, bmin=0, bmax=0.2, phi=1.25, B=0.55*4,
              rho_B=0.9, p=0.85*4, tax0=0.334, T_w=0.143, mup=7 / 6, muw=1.1, kappap=0.01, rho_e=0.966, sd_e=0.92,
              nS=11, amax=1000, amin=0, nA=500, noisy=False, pi_beta=None, Pi_beta=None, MPC_target=0.25, maxit=10,
              MPC_tol=0.001, T_rule=None, tax_rule=None):
    """Solve steady state of full GE model.

    Quantities with closed forms (Z, F, w, div, pshare, T0, G) are computed
    analytically with L = Y = 1. Two nested loops then calibrate the household side.

    Inner loop: broyden_solver finds (betamax, vphi) such that household assets A
        equal B + p and the steady-state wage Phillips curve holds
        (vphi = UCE / muw), for a given range of beta heterogeneity.
    Outer loop: bisection on betarange within [bmin, bmax] (at most maxit steps)
        until the aggregate MPC is within MPC_tol of MPC_target.

    Returns
    -------
    dict
        Steady-state values and parameters. It includes household outputs,
        'betamax', 'betarange', 'vphi', 'mpc', 'MPC', the population-average
        'beta', 'kappaw' and 'walras' (the goods-market residual 1 - C - G).

    Raises
    ------
    ValueError
        If maxit < 1, since no calibration attempt would be made. The inner
        residual also raises ValueError for clearly invalid (betamax, vphi) guesses.
    AssertionError
        If T_rule / tax_rule lengths differ from the income grid.
    """
    if maxit < 1:
        raise ValueError(f'maxit must be at least 1, got {maxit}.')

    # set up grid
    a_grid, e_grid, pi_e, Pi = grids(amax, nA, amin, rho_e, sd_e, nS, pi_beta, Pi_beta)

    # Tax and transfer incidence rules. T_rule is normalized inside `taxes`;
    # the scale of tax_rule is absorbed by tax1 (which is 0 in steady state).
    if T_rule is None:
        T_rule = np.ones_like(e_grid)
    if tax_rule is None:
        tax_rule = np.ones_like(e_grid)
    assert len(T_rule) == len(tax_rule) == len(e_grid), 'Incidence rule is inconsistent with income grid.'

    # solve analytically what we can
    Z = mup * (1 - p*r)
    F = Z - 1
    w = Z / mup
    div = 1 - w
    pshare = p / (p + B)
    T0 = T_w * w
    G = tax0 * w - r * B - T0  # uses L=1
    tax1, T1 = 0, 0
    T, tax, atw = taxes(T_rule, pi_e, T0, T1, tax0, tax1, tax_rule, w)

    # initialize ss dict
    ss = {'B': B, 'Y': 1, 'mup': mup, 'muw': muw, 'L': 1, 'pi': 0, 'div': div, 'phi': phi, 're': r, 'rb': r,
          'G': G, 'G_ss': G, 'B_ss': B, 'r': r, 'rstar': r, 'rho_B': rho_B, 'wage_adjustment_cost': 0,
          'pshare': pshare, 'p': p, 'Z': Z, 'F': F, 'nu': nu, 'i': r, 'tax_all': tax0, 'REV_ss': w * tax0}

    # initialize guess for policy function iteration
    coh = (1 + r) * a_grid[np.newaxis, :] + atw[:, np.newaxis] * e_grid[:, np.newaxis] + T[:, np.newaxis]
    uc = (0.2 * coh) ** (-sigma)

    # residual function
    def res(betarange0):
        # residual function conditional on betarange
        def res_inner(x):
            # guesses with local names
            betamax0, vphi0 = x

            # update exogenous processes: two beta types, each repeated over the nS income states
            betas0 = np.array([betamax0 - betarange0, betamax0])
            beta_grid0 = np.kron(betas0, np.ones(nS))

            if betamax0 > 0.999999/(1 + r) or vphi0 < 0.001:
                raise ValueError('Clearly invalid inputs')

            # solve HH problem
            out = household.ss(uc=uc, Pi=Pi, a_grid=a_grid, e_grid=e_grid, beta_grid=beta_grid0, L=1, w=w,
                               sigma=sigma, T0=T0, T1=T1, T_rule=T_rule, tax0=tax0, tax_rule=tax_rule,
                               tax1=tax1, pi_e=pi_e, rpost=r, rsub=r)

            asset_mkt = out['A'] - B - p
            wnkpc = out['UCE'] / muw - vphi0

            # save this into ss
            ss.update({**out, 'betamax': betamax0, 'betarange': betarange0, 'vphi': vphi0})

            return np.array([asset_mkt, wnkpc])

        # solve for params conditional on betarange
        print(f'Starting inner loop for betarange = {betarange0}')
        x0 = np.array([betamax_guess, vphi_guess])
        _ = broyden_solver(res_inner, x0, noisy=noisy)

        # MPCs
        mpc = sep.mpcs(ss['c'], ss['a'], a_grid, r)
        MPC = np.vdot(ss['D'], mpc)
        MPC_resid = MPC - MPC_target
        print(f'MPC= {MPC:0.3f}')
        print('------------------------------------------')

        ss.update({'mpc': mpc, 'MPC': MPC})

        return MPC_resid

    # bisect on betarange (a wider range of betas raises the MPC)
    for it in range(maxit):
        bmid = (bmin + bmax) / 2
        error = res(bmid)
        if error > MPC_tol:
            # MPC too high
            bmax = bmid
        elif error < - MPC_tol:
            # MPC too low
            bmin = bmid
        else:
            break
    # NOTE: if maxit is exhausted without meeting MPC_tol, the calibration at the
    # last midpoint is kept without warning.

    # scale slope of phillips curve to be equivalent to sticky price model
    kappaw = kappap * ss['vphi'] / mup

    # check Walras's law: goods-market residual implied by household consumption
    walras = 1 - ss['C'] - G

    # Population-average beta. The low type (betamax - bmid) occupies the first nS
    # joint states and the high type the last nS, weighted by their mass in pi_e.
    # This reduces to betamax - bmid / 2 for the default equal type weights and
    # remains correct when a custom pi_beta is supplied.
    share_low = np.sum(pi_e[:nS]) / np.sum(pi_e)
    ss.update({'beta': ss['betamax'] - bmid * share_low, 'kappaw': kappaw, 'walras': walras})

    return ss


def grids(amax, nA, amin, rho_e, sd_e, nS, pi_beta=None, Pi_beta=None):
    """Build the asset grid and the joint (beta type, productivity) Markov chain.

    Productivity is discretized with markov_rouwenhorst into nS states. If
    either pi_beta or Pi_beta is None, both fall back to two equally weighted
    types with an identity transition matrix (types never switch). Joint states
    are ordered type-major: joint index = type * nS + productivity state.

    Returns
    -------
    a_grid : asset grid from agrid
    e_grid : productivity level of each joint state
    pi_e : joint distribution, kron(pi_beta, productivity distribution)
    Pi : joint transition matrix, kron(Pi_beta, productivity transitions)
    """
    a_grid = agrid(amax=amax, n=nA, amin=amin)

    use_default_types = pi_beta is None or Pi_beta is None
    if use_default_types:
        pi_beta, Pi_beta = np.array([0.5, 0.5]), np.eye(2)

    e_levels, pi_e_single, Pi_e_single = markov_rouwenhorst(rho=rho_e, sigma=sd_e, N=nS)

    # every beta type carries its own copy of the productivity grid
    joint_e = np.kron(np.ones_like(pi_beta), e_levels)
    joint_pi = np.kron(pi_beta, pi_e_single)
    joint_Pi = np.kron(Pi_beta, Pi_e_single)

    return a_grid, joint_e, joint_pi, joint_Pi
